
Conversation

@AAnoosheh
Contributor

@AAnoosheh AAnoosheh commented Jan 8, 2026

What does this PR do?

Type of change: New feature

Overview: Adds a new KL-divergence logits loss that uses only the top-k vocabulary logits

Usage

# Add a code snippet demonstrating how to use this
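
A minimal, hypothetical sketch of enabling the new loss is shown below. The logit_kl_topk and logit_kl_temperature fields are described in the review walkthrough further down; the import path matches the changed file, and the remaining DistillationConfig fields are assumed to have usable defaults.

# Hypothetical sketch, not taken from this PR's description
from modelopt.torch.distill.plugins.megatron import DistillationConfig

distill_config = DistillationConfig(
    logit_kl_temperature=1.0,  # temperature applied to the KL term
    logit_kl_topk=5,           # keep only the top-5 vocab logits per token; None means full-vocab KL
)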

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Added Top-K logit filtering capability for knowledge distillation workflows, enabling selective focus on high-probability tokens.
  • Improvements

    • Made distributed tensor model-parallel operations gradient-aware, so gradients flow through the all-reduce used in the distillation loss.
    • Simplified legacy distributed operation constructs.
  • Tests

    • Introduced comprehensive test coverage for Megatron-based distillation, validating both standard and Top-K filtering variants.


@AAnoosheh AAnoosheh self-assigned this Jan 8, 2026
@copy-pr-bot

copy-pr-bot bot commented Jan 8, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


@AAnoosheh AAnoosheh force-pushed the aanoosheh/topk-kdloss branch from 06d057f to e7d33a7 on January 8, 2026 16:48
@codecov

codecov bot commented Jan 8, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 74.62%. Comparing base (b484efb) to head (b128707).
⚠️ Report is 4 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #747   +/-   ##
=======================================
  Coverage   74.62%   74.62%           
=======================================
  Files         192      192           
  Lines       18989    18989           
=======================================
  Hits        14171    14171           
  Misses       4818     4818           

☔ View full report in Codecov by Sentry.

@AAnoosheh AAnoosheh marked this pull request as ready for review January 9, 2026 20:02
@AAnoosheh AAnoosheh requested a review from a team as a code owner January 9, 2026 20:02
@AAnoosheh AAnoosheh force-pushed the aanoosheh/topk-kdloss branch from e7a34bf to 3093b8a on January 10, 2026 21:30
@AAnoosheh AAnoosheh force-pushed the aanoosheh/topk-kdloss branch from 335844f to b128707 on January 12, 2026 21:59
@coderabbitai
Contributor

coderabbitai bot commented Jan 12, 2026

📝 Walkthrough


Added Top-K logit KL-divergence support to Megatron-based distillation by introducing TopKLogitsKLLoss class and logit_kl_topk configuration parameter. The implementation includes distributed tensor model-parallel aware top-K gathering and KL computation, with updated LogitsKLLoss for multi-TP handling.

Changes

  • Megatron Distillation - Top-K KL Support (modelopt/torch/distill/plugins/megatron.py):
    Introduced TopKLogitsKLLoss class implementing top-K logit extraction, global top-K gathering across TP ranks, and KL computation. Added logit_kl_topk field to DistillationConfig. Modified setup_distillation_config to conditionally instantiate TopKLogitsKLLoss vs LogitsKLLoss. Updated LogitsKLLoss to use TP-aware all_reduce for denominators. Refactored logits loss detection to check for "Logits" in key. Removed legacy all_reduce constructs.
  • Distillation Tests (tests/gpu/torch/distill/plugins/test_distill_megatron.py):
    New comprehensive test suite validating both standard LogitsKLLoss and TopKLogitsKLLoss under tensor model parallelism. Tests cover Megatron GPT model initialization, distillation configuration setup, forward/backward passes, and gradient flow across distributed GPU environments.

Sequence Diagram(s)

sequenceDiagram
    actor Student as Student Model
    participant LSK as Top-K KL Loss
    participant TP as TP Ranks
    actor Teacher as Teacher Model
    
    Student->>LSK: predictions (logits)
    Teacher->>LSK: targets (logits)
    
    LSK->>LSK: Extract local top-K logits per rank
    LSK->>TP: Gather top-K indices from all TP ranks
    TP->>LSK: Return global top-K across TP
    LSK->>LSK: Collect top-K logits from all ranks
    LSK->>LSK: Compute log probabilities on global top-K
    LSK->>LSK: Compute KL divergence with temperature scaling
    LSK-->>Student: KL loss value
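
A minimal sketch of the final KL step in the diagram above, assuming the student and teacher logits have already been restricted to the same global top-K entries; shapes, names, and the reduction are illustrative rather than the PR's exact implementation:

import torch
import torch.nn.functional as F

def topk_kl_loss(student_topk: torch.Tensor, teacher_topk: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # *_topk: [batch, seq, K] logits restricted to the shared global top-K vocab entries
    student_logprobs = F.log_softmax(student_topk / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_topk / temperature, dim=-1)
    # KL(teacher || student) per token, then averaged and scaled by T^2 as is standard in distillation
    kl_per_token = F.kl_div(student_logprobs, teacher_probs, reduction="none").sum(dim=-1)
    return kl_per_token.mean() * temperature**2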

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks: 3 passed
  • Description check: ✅ Passed (check skipped because CodeRabbit's high-level summary is enabled)
  • Title check: ✅ Passed (the title clearly and directly describes the main feature introduced in the pull request: a Top-K KL divergence loss for distillation)
  • Docstring coverage: ✅ Passed (docstring coverage is 100.00%, above the required threshold of 80.00%)



Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
modelopt/torch/distill/plugins/megatron.py (1)

73-79: Consider validating logit_kl_topk > 0 when set.

The __post_init__ validates other parameters but doesn't check that logit_kl_topk is positive when not None. A value of 0 or negative would cause issues downstream in TopKLogitsKLLoss.

Suggested validation
     def __post_init__(self):
         assert len(self.logit_layers) == 2, f"{self.logit_layers=}"
         assert all(len(pair) in (2, 3) for pair in self.intermediate_layer_pairs), (
             f"{self.intermediate_layer_pairs=}"
         )
         assert self.kd_loss_scale > 0, f"{self.kd_loss_scale=}"
         assert self.logit_kl_temperature > 0, f"{self.logit_kl_temperature=}"
+        assert self.logit_kl_topk is None or self.logit_kl_topk > 0, f"{self.logit_kl_topk=}"
tests/gpu/torch/distill/plugins/test_distill_megatron.py (1)

38-127: Solid integration test for LogitsKLLoss.

The test covers the essential flow: model creation, distillation setup, forward pass, loss computation, and backward pass.

Consider adding gradient verification for more robust testing:

# After backward pass, verify gradients exist
for name, param in distillation_model.named_parameters():
    if param.requires_grad and param.grad is not None:
        assert param.grad.abs().sum() > 0, f"Zero gradients for {name}"
        break  # At least one param has non-zero gradient
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 727da95 and b128707.

📒 Files selected for processing (2)
  • modelopt/torch/distill/plugins/megatron.py
  • tests/gpu/torch/distill/plugins/test_distill_megatron.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/gpu/torch/distill/plugins/test_distill_megatron.py (6)
tests/_test_utils/import_helper.py (1)
  • skip_if_no_megatron (46-77)
tests/_test_utils/torch/distributed/utils.py (1)
  • spawn_multiprocess_job (51-65)
tests/_test_utils/torch/megatron/models.py (1)
  • get_mcore_gpt_model (125-244)
tests/_test_utils/torch/megatron/utils.py (1)
  • run_mcore_inference_with_dummy_input (122-129)
modelopt/torch/distill/plugins/megatron.py (2)
  • DistillationConfig (52-95)
  • adjust_distillation_model_for_mcore (558-616)
modelopt/torch/distill/distillation_model.py (2)
  • loss_balancer (134-136)
  • compute_kd_loss (237-288)
modelopt/torch/distill/plugins/megatron.py (1)
modelopt/torch/distill/distillation_model.py (1)
  • forward (209-235)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (7)
modelopt/torch/distill/plugins/megatron.py (4)

129-137: LGTM!

The conditional instantiation logic is clear and correctly passes the appropriate parameters to each loss class.


346-370: LGTM!

The use of dist_nn.functional.all_reduce for computing global softmax denominators correctly preserves gradients through the distributed operation. The comment on lines 347-348 clearly explains the rationale.


373-458: Well-structured Top-K implementation with proper TP handling.

The implementation correctly:

  • Preserves gradients through dist_nn.functional.all_gather
  • Guards against top_k exceeding total vocabulary size
  • Handles edge cases where local vocabulary size is smaller than top_k
  • Avoids unnecessary TP reduction since all ranks compute the same global top-K

The docstring appropriately warns users about memory/communication implications for large K values.
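
As a minimal sketch of the gathering pattern described in this comment (names and shapes are assumptions, not the PR's actual code), each TP rank contributes its local top-K candidates through a differentiable all_gather and the global top-K is then selected from the union:

import torch
import torch.distributed.nn as dist_nn

def global_topk_logits(local_logits: torch.Tensor, k: int, tp_group) -> torch.Tensor:
    # local_logits: [batch, seq, local_vocab] shard of the vocabulary held by this TP rank
    k_local = min(k, local_logits.size(-1))  # guard: a local shard may hold fewer than k entries
    local_topk, _ = local_logits.topk(k_local, dim=-1)
    # Differentiable gather: returns one tensor per TP rank, with gradients flowing back through it
    gathered = dist_nn.functional.all_gather(local_topk, group=tp_group)
    candidates = torch.cat(gathered, dim=-1)  # union of per-rank candidates
    global_topk, _ = candidates.topk(min(k, candidates.size(-1)), dim=-1)
    return global_topk  # identical on every TP rank, so no further TP reduction is needed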


490-493: Reasonable change to accommodate TopKLogitsKLLoss.

The "Logits" in _key check is necessary since TopKLogitsKLLoss doesn't start with "Logits". Be aware this could match unintended keys if future loss classes contain "Logits" anywhere in the name.

tests/gpu/torch/distill/plugins/test_distill_megatron.py (3)

19-33: LGTM!

The skip guard pattern correctly prevents test failures when Megatron or required dependencies are unavailable.


129-217: Test structure is appropriate.

The larger vocab_size=128 correctly provides enough vocabulary entries for meaningful top-k testing with top_k=5.

The duplication with _test_logits_kl_loss could be reduced by extracting common setup into a helper, but the explicit test isolation is acceptable for test readability.


220-237: Consider handling edge cases for device count.

The tests assume torch.cuda.device_count() >= 2 for meaningful tensor parallelism testing. If only one GPU is available, size=1 would run without TP, which may not exercise the distributed code paths.

Consider adding a skip condition:

import pytest

def test_logits_kl_loss():
    """Test LogitsKLLoss with TP parallelism."""
    if torch.cuda.device_count() < 2:
        pytest.skip("Need at least 2 GPUs for TP testing")
    set_seed(SEED)
    spawn_multiprocess_job(
        size=torch.cuda.device_count(),
        job=_test_logits_kl_loss,
        backend="nccl",
    )

Alternatively, verify that the existing helper functions handle this appropriately in the test infrastructure.

# We can't use standard all_reduce function here since the computation
# that follows it isn't identical across TP ranks.
denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1, keepdim=True)
denom_teacher = dist_nn.functional.all_reduce(denom_teacher, group=tp_group)
Collaborator

Is torch.distributed.nn.functional.all_reduce same as torch.distributed.all_reduce?

Contributor Author

No, it's the premium edition which allows gradient backprop through it
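
In other words, torch.distributed.nn.functional.all_reduce is registered with autograd and returns a new tensor, so gradients propagate back through the reduction, whereas torch.distributed.all_reduce mutates its input in place and is not differentiable. A minimal sketch of the pattern quoted above (variable names assumed):

import torch
import torch.distributed.nn as dist_nn

def global_softmax_denominator(local_logits: torch.Tensor, tp_group) -> torch.Tensor:
    # Partial softmax denominator over this rank's vocabulary shard
    local_denom = torch.sum(torch.exp(local_logits), dim=-1, keepdim=True)
    # Differentiable collective: the summed result comes back as a new tensor and
    # gradients flow to local_logits on every TP rank during backward
    return dist_nn.functional.all_reduce(local_denom, group=tp_group)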

@AAnoosheh AAnoosheh merged commit b813ab5 into main Jan 13, 2026
36 checks passed
@AAnoosheh AAnoosheh deleted the aanoosheh/topk-kdloss branch January 13, 2026 15:09
jingyu-ml pushed a commit that referenced this pull request Jan 14, 2026